"""
VITAL MEASURE
"""
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
import numpy as np
from .utils import process_text
import evaluate
from seqeval.metrics import f1_score as entity_score, precision_score, recall_score
from difflib import SequenceMatcher
from sklearn.metrics import f1_score, matthews_corrcoef 
from bart_score import BARTScorer
from bs4 import BeautifulSoup as bs
from bs4 import NavigableString, Tag
from glob import glob
import spacy

import evaluate
import textstat

_CITATION = """
"""


def align_tokens(seq1, seq2):
    """Align two (token, tag) sequences that may tokenize the same text differently.

    Walks both sequences in lockstep.  When the current tokens match exactly,
    the pair is emitted with both tags.  When one token is a prefix of the
    other, the shared prefix is emitted and the longer token's remainder is
    left in place for the next step (a ``B-`` tag on the remainder is demoted
    to ``I-`` so entity boundaries stay consistent).  Tokens sharing no prefix
    are skipped on both sides.

    :param seq1: list of ``(token, tag)`` pairs (e.g. predictions).
    :param seq2: list of ``(token, tag)`` pairs (e.g. gold labels).
    :return: list of ``(token, tag1, tag2)`` triples for the aligned tokens.
    """
    # Work on copies: the prefix-splitting below rewrites entries, and the
    # caller's lists must not be mutated as a side effect.
    seq1 = list(seq1)
    seq2 = list(seq2)
    aligned_seq = []
    i, j = 0, 0
    while i < len(seq1) and j < len(seq2):
        token1, tag1 = seq1[i]
        token2, tag2 = seq2[j]
        if token1 == token2:
            aligned_seq.append((token1, tag1, tag2))
            i += 1
            j += 1
        elif token2.startswith(token1):
            # seq2's token is longer: emit the shared prefix and keep the
            # remainder in place with a continuation (I-) tag.
            aligned_seq.append((token1, tag1, tag2))
            seq2[j] = (token2[len(token1):], "I-" + tag2[2:] if tag2.startswith("B-") else tag2)
            i += 1
        elif token1.startswith(token2):
            aligned_seq.append((token2, tag1, tag2))
            seq1[i] = (token1[len(token2):], "I-" + tag1[2:] if tag1.startswith("B-") else tag1)
            j += 1
        else:
            # No common prefix at all: drop both tokens and move on.
            i += 1
            j += 1
    return aligned_seq


class Classification(Task):
    """Multiple-choice classification task scored by substring matching.

    The model's free-form completion is mapped back onto one of
    ``doc["choices"]``: the first choice found as a substring of the
    completion wins; if none matches, the prediction is scored as
    ``"missing"``.  Reports accuracy, weighted/macro F1, the missing rate,
    and optionally the Matthews correlation coefficient.
    """

    CALCULATE_MCC = False  # also report Matthews correlation coefficient
    LOWER_CASE = True      # lower-case gold and prediction before matching
    FIRST_LETTER = False   # compare only the completion's first character
    VERSION = 1
    EVAL_LAST_TURN = True

    def reformulate_turn_req(self, req, turn_request, turn):
        return req

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        return self.dataset["test"]

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        cont_request = rf.greedy_until(ctx, {"until": None})
        return cont_request

    def doc_to_decontamination_query(self, doc):
        return doc["text"]

    def doc_to_text(self, doc, prompt=False):
        return doc["query"]

    def doc_to_target(self, doc):
        return doc["answer"]

    def process_results(self, doc, results):
        """Map the raw completion onto a choice and emit per-document metrics.

        :param doc: document with ``choices`` and the ``gold`` choice index.
        :param results: list whose first element is the model completion.
        :return: dict of metric name -> value or (prediction, gold) pair.
        """
        gold: str = doc["choices"][doc["gold"]]
        if self.LOWER_CASE:
            gold = gold.lower()
        ini_result = results[0].strip()
        if self.LOWER_CASE:
            ini_result = ini_result.lower()

        if self.FIRST_LETTER:
            # Keep only the leading character (letter-style answers); guard
            # against an empty completion, which would raise IndexError.
            stripped = ini_result.strip()
            ini_result = stripped[0] if stripped else ""

        result = None
        for choice in doc["choices"]:
            if self.LOWER_CASE:
                choice = choice.lower()
            if choice in ini_result:
                result = choice
                break
        if result is None:
            result = "missing"

        acc = 1.0 if gold == result else 0.0

        # Use a fresh name so the `results` parameter is not shadowed.
        scores = {
            "acc": acc,
            "missing": int(result == "missing"),
            "f1": (result, gold),
            "macro_f1": (result, gold),
        }

        if self.CALCULATE_MCC:
            scores["mcc"] = (result, gold)

        return scores

    def higher_is_better(self):
        metrics = {
            "acc": True,
            "f1": True,
            "macro_f1": True,
            "missing": False,
        }
        if self.CALCULATE_MCC:
            metrics["mcc"] = True
        return metrics

    def weighted_f1(self, items):
        """Weighted-average F1 over (pred, gold) pairs; label set drawn from golds."""
        preds, golds = zip(*items)
        labels = list(set(golds))
        preds = np.array(preds)
        golds = np.array(golds)
        f1 = f1_score(golds, preds, average="weighted", labels=labels)
        return f1

    def macro_f1(self, items):
        """Macro-average F1 over (pred, gold) pairs; label set drawn from golds."""
        preds, golds = zip(*items)
        labels = list(set(golds))
        preds = np.array(preds)
        golds = np.array(golds)
        f1 = f1_score(golds, preds, average="macro", labels=labels)
        return f1

    def matthews_corrcoef(self, items):
        """Matthews correlation; predictions outside the gold label set map to -1."""
        preds, golds = zip(*items)
        labels = {label: i for i, label in enumerate(list(set(golds)))}
        preds = [labels.get(pred, -1) for pred in preds]
        golds = [labels.get(gold, -1) for gold in golds]
        # Resolves to the sklearn function imported at module level, not this method.
        return matthews_corrcoef(golds, preds)

    def aggregation(self):
        metrics = {
            "acc": mean,
            "missing": mean,
            "f1": self.weighted_f1,
            "macro_f1": self.macro_f1,
        }
        if self.CALCULATE_MCC:
            metrics["mcc"] = self.matthews_corrcoef
        return metrics


class SequentialLabeling(Task):
    """Token-level sequence-labeling task parsed from "token: LABEL" lines.

    The model is expected to emit one ``token: LABEL`` line per input token;
    lines that fail to parse, whose token does not match the input, or whose
    label is unknown fall back to ``O`` (entity F1) or ``-1`` (label F1).
    """

    VERSION = 1
    DATASET_NAME = None
    LMAP = {"O": 0}  # label -> integer id; subclasses extend this mapping
    EVAL_LAST_TURN = True

    def reformulate_turn_req(self, req, turn_request, turn):
        return req

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        return self.dataset["test"]

    def doc_to_text(self, doc):
        return doc["query"]

    def doc_to_target(self, doc):
        return "\nAnswer: " + doc["answer"]

    def process_results(self, doc, results):
        # Both metrics need the gold labels, the raw completion, and the
        # input tokens to re-align the model output.
        return {
            "entity_f1": (doc["label"], results[0], doc["token"]),
            "f1": (doc["label"], results[0], doc["token"]),
        }

    def higher_is_better(self):
        return {
            "f1": True,
            "entity_f1": True,
        }

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        cont_request = rf.greedy_until(ctx, {"until": None})
        return cont_request

    def process_result(self, pred, gold, tokens):
        """Parse a "token: LABEL"-per-line completion into a BIO tag list.

        Positions with unparseable lines, mismatched tokens, or labels not
        in ``LMAP`` keep the default ``O`` tag.
        """
        format_pred = ["O"] * len(gold)
        for index, pre in enumerate(pred.split("\n")[: len(tokens)]):
            try:
                word, label = pre.split(":")
            except ValueError:
                # Line does not split into exactly "word:label"; skip it.
                continue
            if word == tokens[index] and label in self.LMAP.keys():
                format_pred[index] = label
        return format_pred

    def entity_f1(self, items):
        """Entity-level (seqeval) F1 over all documents."""
        golds, preds, tokens = zip(*items)

        list_preds = [
            self.process_result(pred, gold, token)
            for pred, gold, token in zip(preds, golds, tokens)
        ]
        f1 = entity_score(golds, list_preds)
        return f1

    def process_label_result(self, pred, gold, tokens):
        """Parse a completion into integer label ids; unknowns become -1."""
        format_pred = [-1] * len(gold)
        for index, pre in enumerate(pred.split("\n")[: len(tokens)]):
            try:
                word, label = pre.split(":")
            except ValueError:
                # Line does not split into exactly "word:label"; skip it.
                continue
            if word == tokens[index]:
                format_pred[index] = self.LMAP.get(label, -1)
        return format_pred

    def label_f1(self, items):
        """Token-level weighted F1 over the flattened label ids."""
        golds, preds, tokens = zip(*items)

        list_preds = [
            self.process_label_result(pred, gold, token)
            for pred, gold, token in zip(preds, golds, tokens)
        ]
        list_preds = [item for sublist in list_preds for item in sublist]
        golds = [self.LMAP[item] for sublist in golds for item in sublist]
        f1 = f1_score(golds, list_preds, average="weighted")
        return f1

    def aggregation(self):
        return {
            "entity_f1": self.entity_f1,
            "f1": self.label_f1,
        }


class NER(SequentialLabeling):
    """NER task where predictions and golds are HTML with entity-tagged spans.

    Entity spans are HTML tags whose first CSS class names the entity type
    (one of ``entity_list``); all other text is tagged ``O``.
    """

    # spaCy pipeline used to tokenize tagged span text.
    # NOTE(review): loaded at class-definition (import) time.
    py_nlp = spacy.load("en_core_web_lg")
    entity_list = ["problem", "treatment", "test", "drug"]

    def html2bio(self, prediction_html, label_html):
        """Convert prediction/label HTML into a pair of aligned BIO tag lists.

        :param prediction_html: model output HTML.
        :param label_html: gold-standard HTML.
        :return: ``(prediction_tags, label_tags)`` of equal length.
        """

        def parse_html(html):
            # Walk the top-level children: bare text becomes O-tagged words;
            # tags become B-/I- spans of their first CSS class.
            soup = bs(html, "html.parser")
            words = []
            bio_tags = []
            for child in soup.children:
                if isinstance(child, NavigableString):
                    for word in child.split():
                        words.append(word)
                        bio_tags.append("O")
                elif isinstance(child, Tag):
                    child_words = [token.text for token in self.py_nlp(child.get_text())]
                    try:
                        entity = child.attrs['class'][0]
                    except (KeyError, IndexError):
                        # Tag has no class attribute (or it is empty).
                        entity = 'O'
                    for i, word in enumerate(child_words):
                        words.append(word)
                        if entity != 'O' and entity in self.entity_list:
                            bio_tags.append(f"B-{entity}" if i == 0 else f"I-{entity}")
                        else:
                            bio_tags.append("O")
            return words, bio_tags

        # Parse prediction and label HTMLs.
        prediction_words, prediction_bio = parse_html(prediction_html)
        label_words, label_bio = parse_html(label_html)

        # Align the two tokenizations so seqeval sees equal-length lists.
        aligned_sequences = align_tokens(list(zip(prediction_words, prediction_bio)),
                                         list(zip(label_words, label_bio)))

        aligned_prediction_sequence = [tag1 for _, tag1, _ in aligned_sequences]
        aligned_label_sequence = [tag2 for _, _, tag2 in aligned_sequences]

        return aligned_prediction_sequence, aligned_label_sequence

    def process_results(self, doc, results):
        return {
            "precision": (doc["answer"], results[0]),
            "recall": (doc["answer"], results[0]),
            "f1": (doc["answer"], results[0]),
        }

    def higher_is_better(self):
        return {
            "precision": True,
            "recall": True,
            "f1": True,
        }

    def _aligned_bio_lists(self, items):
        """Convert (gold_html, pred_html) items into aligned BIO tag lists.

        Shared by the precision/recall/F1 aggregations, which previously each
        duplicated this conversion loop.
        """
        golds, preds = zip(*items)
        list_preds, list_golds = [], []
        for pred, gold in zip(preds, golds):
            pd, gd = self.html2bio(pred, gold)
            list_preds.append(pd)
            list_golds.append(gd)
        return list_golds, list_preds

    def cal_f1(self, items):
        """Entity-level (seqeval) F1 over the aligned BIO sequences."""
        list_golds, list_preds = self._aligned_bio_lists(items)
        return entity_score(list_golds, list_preds)

    def cal_precision(self, items):
        """Entity-level (seqeval) precision over the aligned BIO sequences."""
        list_golds, list_preds = self._aligned_bio_lists(items)
        return precision_score(list_golds, list_preds)

    def cal_recall(self, items):
        """Entity-level (seqeval) recall over the aligned BIO sequences."""
        list_golds, list_preds = self._aligned_bio_lists(items)
        return recall_score(list_golds, list_preds)

    def aggregation(self):
        return {
            "f1": self.cal_f1,
            "precision": self.cal_precision,
            "recall": self.cal_recall,
        }


class AbstractiveSummarization(Task):
    """Abstractive summarization scored with ROUGE, BERTScore, and BARTScore."""

    VERSION = 1
    DATASET_NAME = None
    EVAL_LAST_TURN = True
    # Shared ROUGE metric, loaded once at class-definition time.
    rouge = evaluate.load("rouge")

    def reformulate_turn_req(self, req, turn_request, turn):
        return req

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        return self.dataset["test"]

    def doc_to_text(self, doc):
        return doc["query"]

    def doc_to_target(self, doc):
        return doc["answer"]

    def process_results(self, doc, results):
        # Every metric receives the same (reference, prediction) pair; the
        # real computation happens corpus-wide in aggregation().
        return {
            "rouge1": (doc["answer"], results[0]),
            "rouge2": (doc["answer"], results[0]),
            "rougeL": (doc["answer"], results[0]),
            "bert_score_f1": (doc["answer"], results[0]),
            "bart_score": (doc["answer"], results[0]),
        }

    def higher_is_better(self):
        return {
            "rouge1": True,
            "rouge2": True,
            "rougeL": True,
            "bert_score_f1": True,
            "bart_score": True,
        }

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        cont_request = rf.greedy_until(ctx, {"until": None})
        return cont_request

    def rouge_score(self, items):
        """Compute corpus-level ROUGE once and cache it on the instance.

        ``aggregation`` invokes rouge1/rouge2/rougeL separately; without the
        cache the full corpus ROUGE was recomputed three times.
        """
        if getattr(self, "_cache_rouge", None) is None:
            golds, preds = zip(*items)
            self._cache_rouge = self.rouge.compute(predictions=preds, references=golds)
        return self._cache_rouge

    def rouge1(self, items):
        return self.rouge_score(items)["rouge1"]

    def rouge2(self, items):
        return self.rouge_score(items)["rouge2"]

    def rougeL(self, items):
        return self.rouge_score(items)["rougeL"]

    def bert_score(self, items):
        """Compute multilingual BERTScore once and cache the raw result."""
        if getattr(self, "_cache_bertscore", None) is None:
            golds, preds = zip(*items)
            bertscore = evaluate.load("evaluate-metric/bertscore")
            self._cache_bertscore = bertscore.compute(
                predictions=preds,
                references=golds,
                model_type="bert-base-multilingual-cased",
            )
        return self._cache_bertscore

    def bert_score_f1(self, items):
        """Mean BERTScore F1 over the corpus."""
        res = self.bert_score(items)
        return sum(res["f1"]) / len(res["f1"])

    def bart_score(self, items):
        """Mean BARTScore of predictions against references.

        NOTE(review): requires CUDA and the ``bart_score.pth`` checkpoint at
        the hard-coded path below.
        """
        golds, preds = zip(*items)
        bart_scorer = BARTScorer(device="cuda", checkpoint="facebook/bart-large-cnn")
        bart_scorer.load(path="src/metrics/BARTScore/bart_score.pth")
        res = bart_scorer.score(srcs=preds, tgts=golds, batch_size=8)
        return sum(res) / len(res)

    def aggregation(self):
        return {
            "rouge1": self.rouge1,
            "rouge2": self.rouge2,
            "rougeL": self.rougeL,
            "bert_score_f1": self.bert_score_f1,
            "bart_score": self.bart_score,
        }


class QA(Task):
    """Free-form question answering scored by exact match against the gold answer."""

    VERSION = 1
    DATASET_NAME = None
    EVAL_LAST_TURN = True

    def reformulate_turn_req(self, req, turn_request, turn):
        return req

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        return self.dataset["test"]

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["text"]

    def doc_to_text(self, doc, prompt_only=False):
        return doc["query"]

    def construct_requests(self, doc, ctx):
        """Build the greedy-decoding request for this document's few-shot context.

        :param doc: document from training/validation/test docs.
        :param ctx: str, the fully formatted few-shot context for `doc`.
        """
        return rf.greedy_until(ctx, {"until": None})

    def doc_to_target(self, doc):
        return doc["answer"]

    def process_results(self, doc, results):
        """Exact-match accuracy: 1.0 iff the stripped completion equals the answer."""
        prediction = results[0].strip()
        return {"acc": 1.0 if prediction == doc["answer"] else 0.0}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}


class NormQA(QA):
    """QA scored after SQuAD-style normalization, with token-overlap F1.

    Normalization lower-cases, removes punctuation, and collapses
    whitespace before comparing prediction and gold answer.
    """

    @staticmethod
    def _normalize_text(s):
        """Lower-case *s*, strip punctuation, and collapse runs of whitespace."""
        import string

        exclude = set(string.punctuation)
        # Join with "" (not " "): joining with a space inserted a space
        # between EVERY character, turning the token-level F1 below into a
        # character-level comparison.  "" matches the standard SQuAD
        # normalize_answer implementation.
        no_punc = "".join(ch for ch in s.lower() if ch not in exclude)
        return " ".join(no_punc.split())

    def process_results(self, doc, results):
        """Score the completion with normalized exact match and token F1."""
        gold = self._normalize_text(doc["answer"])
        result = self._normalize_text(results[0])

        acc = 1.0 if result == gold else 0.0
        f1 = self.compute_f1(result, gold)

        return {
            "acc": acc,
            "f1": f1,
        }

    def compute_f1(self, prediction, truth):
        """Token-overlap F1 between normalized prediction and truth strings."""
        pred_tokens = prediction.split()
        truth_tokens = truth.split()
        # If either side is empty, F1 is 1 only when both are empty.
        if len(pred_tokens) == 0 or len(truth_tokens) == 0:
            return int(pred_tokens == truth_tokens)
        common_tokens = set(pred_tokens) & set(truth_tokens)
        # No shared tokens means zero precision and recall.
        if len(common_tokens) == 0:
            return 0
        prec = len(common_tokens) / len(pred_tokens)
        rec = len(common_tokens) / len(truth_tokens)
        return 2 * (prec * rec) / (prec + rec)

    def higher_is_better(self):
        return {
            "acc": True,
            "f1": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "f1": mean,
        }


class PUBMEDQA(Classification):
    """PubMedQA test split, scored with the default case-insensitive matching."""

    DATASET_PATH = "clinicalnlplab/pubmedqa_test"


class MedQA(Classification):
    """MedQA test split."""

    DATASET_PATH = "clinicalnlplab/medQA_test"
    # Match choices case-sensitively for this dataset.
    LOWER_CASE = False


class MedMCQA(Classification):
    """MedMCQA test split."""

    DATASET_PATH = "clinicalnlplab/medMCQA_test"
    # Match choices case-sensitively for this dataset.
    LOWER_CASE = False


class EmrQA(NormQA):
    """emrQA test split, scored with normalized exact match and token F1."""

    DATASET_PATH = "clinicalnlplab/emrqa_test"


class I2B2(NER):
    """i2b2 NER test split, loaded from a local path."""

    DATASET_PATH = "local:/I2B2_test"


class DDI2013(Classification):
    """DDI-2013 drug-drug interaction classification.

    Overrides ``process_results`` with case-sensitive matching that
    requires exactly one choice to appear in the completion; zero or
    multiple matches are scored as missing.
    """

    DATASET_PATH = "clinicalnlplab/DDI2013_test"

    def process_results(self, doc, results):
        gold: str = doc["choices"][doc["gold"]]
        prediction = results[0].strip()

        # Collect every choice mentioned in the completion; only an
        # unambiguous single match counts as a prediction.
        matches = [choice for choice in doc["choices"] if choice in prediction]
        result = matches[0] if len(matches) == 1 else "missing"

        acc = 1.0 if gold == result else 0.0

        scores = {
            "acc": acc,
            "missing": int(result == "missing"),
            "f1": (result, gold),
            "macro_f1": (result, gold),
        }
        if self.CALCULATE_MCC:
            scores["mcc"] = (result, gold)
        return scores


class HoC(Classification):
    """Hallmarks of Cancer: multi-label classification over ten hallmark classes."""

    DATASET_PATH = "clinicalnlplab/HoC_test"
    CHOICES = [ "sustaining proliferative signaling", "evading growth suppressors", "resisting cell death", "enabling replicative immortality", "inducing angiogenesis", "activating invasion and metastasis", "genomic instability and mutation", "tumor promoting inflammation", "cellular energetics", "avoiding immune destruction" ]

    def process_results(self, doc, results):
        """Emit one binary (pred, gold) pair per label plus the full label vector."""
        gold = doc["gold"]
        ini_result = results[0].strip()

        # Multi-label: mark every choice mentioned anywhere in the completion.
        result = [0] * len(self.CHOICES)
        for index, choice in enumerate(doc["choices"]):
            if choice in ini_result:
                result[index] = 1

        scores = {}
        for index, choice in enumerate(doc["choices"]):
            scores[f"{choice}_f1"] = (result[index], gold[index])

        scores["macro_f1"] = (result, gold)

        return scores

    def higher_is_better(self):
        results = {}
        for choice in self.CHOICES:
            results[f"{choice}_f1"] = True
        # Fixed: this key previously read "micro_f1", which does not match the
        # "macro_f1" metric emitted by process_results()/aggregation().
        results["macro_f1"] = True
        return results

    def binary_f1(self, items):
        """Binary F1 for one label's (pred, gold) indicator pairs."""
        preds, golds = zip(*items)
        labels = list(set(golds))
        preds = np.array(preds)
        golds = np.array(golds)
        f1 = f1_score(golds, preds, average="binary", labels=labels)
        return f1

    def macro_label_f1(self, items):
        """Macro-average F1 over the full multi-label indicator matrices."""
        preds, golds = zip(*items)
        preds = np.array(preds)
        golds = np.array(golds)
        # Fixed argument order: sklearn's f1_score expects (y_true, y_pred).
        return f1_score(golds, preds, average='macro')

    def aggregation(self):
        results = {}
        for choice in self.CHOICES:
            results[f"{choice}_f1"] = self.binary_f1
        results["macro_f1"] = self.macro_label_f1
        return results


class MTSample(Classification):
    """MTSamples test split, scored with the default Classification metrics."""

    DATASET_PATH = "clinicalnlplab/MTSample_test"


class PubmedSum(AbstractiveSummarization):
    """PubMed summarization test split."""

    DATASET_PATH = "clinicalnlplab/PubmedSumm_test"


class MimicSum(AbstractiveSummarization):
    """MIMIC summarization test split."""

    DATASET_PATH = "clinicalnlplab/MIMIC_SUM_test"


class BioNLI(Classification):
    """BioNLI natural language inference test split."""

    DATASET_PATH = "clinicalnlplab/BioNLI_test"


class MedNLI(Classification):
    """MedNLI natural language inference test split."""

    DATASET_PATH = "clinicalnlplab/MedNLI_test"
